{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:20.586340Z",
     "start_time": "2023-03-04T15:00:20.581342Z"
    }
   },
   "outputs": [],
   "source": [
    "# 04_RotatE_AD_drug_repurposing\n",
    "#\n",
    "# created by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on March 4, 2023\n",
    "# updated by LuYF-Lemon-love <luyanfeng_nlp@qq.com> on March 5, 2023\n",
    "#\n",
    "# This script shows how to use our pretrained model (RotatE) for drug repurposing (Alzheimer's disease).\n",
    "#\n",
    "# Required packages:\n",
    "#          csv\n",
    "#          numpy\n",
    "#          torch\n",
    "#\n",
    "# Required files:\n",
    "#          ./prerequisites/infer_drug.tsv -> FDA-approved drugs in DrugBank\n",
    "#          ./prerequisites/ad_drugs.txt\n",
    "#          ../../data/drkg/entities.tsv\n",
    "#          ../../data/drkg/relations.tsv\n",
    "#          ../01-model/ckpts/RotatE_All_DRKG_0/All_DRKG_RotatE_entity.npy\n",
    "#          ../01-model/ckpts/RotatE_All_DRKG_0/All_DRKG_RotatE_relation.npy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Alzheimer's disease Drug Repurposing via disease-compound relations\n",
    "This example shows how to use **our** pretrained model for drug repurposing."
   ]
  },
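  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collecting Alzheimer's disease entities\n",
    "\n",
    "We first collect the list of Alzheimer's disease (AD) entities in DRKG. Diseases are identified by their DRKG disease IDs. Below, we use all AD disease entities as targets."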
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:20.596315Z",
     "start_time": "2023-03-04T15:00:20.588323Z"
    }
   },
   "outputs": [],
   "source": [
    "AD_disease_list = [\n",
    "'Disease::DOID:10652',\n",
    "'Disease::MESH:C536599',\n",
    "'Disease::MESH:D000544'\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:20.613285Z",
     "start_time": "2023-03-04T15:00:20.598297Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(AD_disease_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:20.622233Z",
     "start_time": "2023-03-04T15:00:20.614254Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Disease::DOID:10652', 'Disease::MESH:C536599', 'Disease::MESH:D000544']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "AD_disease_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Candidate drugs\n",
    "\n",
    "We use the FDA-approved drugs in DrugBank as candidate drugs (drugs with a molecular weight below 250 are excluded). The drug list is in infer\\_drug.tsv."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:20.650158Z",
     "start_time": "2023-03-04T15:00:20.624227Z"
    }
   },
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "# Load the candidate drug list (first column of infer_drug.tsv)\n",
    "drug_list = []\n",
    "with open(\"./prerequisites/infer_drug.tsv\", newline='', encoding='utf-8') as csvfile:\n",
    "    reader = csv.DictReader(csvfile, delimiter='\\t', fieldnames=['drug','ids'])\n",
    "    for row_val in reader:\n",
    "        drug_list.append(row_val['drug'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:20.657146Z",
     "start_time": "2023-03-04T15:00:20.651155Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8104"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(drug_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:20.665118Z",
     "start_time": "2023-03-04T15:00:20.659134Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Compound::DB00605', 'Compound::DB00983', 'Compound::DB01240']"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "drug_list[:3]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Treatment relation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:20.672100Z",
     "start_time": "2023-03-04T15:00:20.667113Z"
    }
   },
   "outputs": [],
   "source": [
    "treatment_list = [\n",
    "'DRUGBANK::treats::Compound:Disease',\n",
    "'GNBR::T::Compound:Disease',\n",
    "'Hetionet::CtD::Compound:Disease'\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:20.680077Z",
     "start_time": "2023-03-04T15:00:20.674095Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['DRUGBANK::treats::Compound:Disease',\n",
       " 'GNBR::T::Compound:Disease',\n",
       " 'Hetionet::CtD::Compound:Disease']"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "treatment_list"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## AD treatment drugs in DRKG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:20.689053Z",
     "start_time": "2023-03-04T15:00:20.681075Z"
    }
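   },
   "outputs": [],
   "source": [
    "ad_drugs = []\n",
    "\n",
    "# Read one compound ID per line; strip() drops the trailing newline\n",
    "# without truncating a final line that has no newline.\n",
    "with open(\"./prerequisites/ad_drugs.txt\", encoding='utf-8') as f:\n",
    "    for drug in f:\n",
    "        ad_drugs.append(drug.strip())"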
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:20.696036Z",
     "start_time": "2023-03-04T15:00:20.691049Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "126"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(ad_drugs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:20.707006Z",
     "start_time": "2023-03-04T15:00:20.697032Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Compound::DB13588',\n",
       " 'Compound::DB09130',\n",
       " 'Compound::DB00244',\n",
       " 'Compound::DB04815',\n",
       " 'Compound::DB03994',\n",
       " 'Compound::DB00712',\n",
       " 'Compound::DB00158',\n",
       " 'Compound::DB00472',\n",
       " 'Compound::DB00331',\n",
       " 'Compound::DB01593',\n",
       " 'Compound::DB11100',\n",
       " 'Compound::DB03128',\n",
       " 'Compound::DB06712',\n",
       " 'Compound::DB00993',\n",
       " 'Compound::DB00877',\n",
       " 'Compound::DB11094',\n",
       " 'Compound::DB04115',\n",
       " 'Compound::DB11805',\n",
       " 'Compound::DB00843',\n",
       " 'Compound::DB00928',\n",
       " 'Compound::DB11068',\n",
       " 'Compound::DB00694',\n",
       " 'Compound::DB11748',\n",
       " 'Compound::DB14500',\n",
       " 'Compound::DB00382',\n",
       " 'Compound::DB00393',\n",
       " 'Compound::DB00490',\n",
       " 'Compound::DB00215',\n",
       " 'Compound::DB03843',\n",
       " 'Compound::DB06756',\n",
       " 'Compound::DB05289',\n",
       " 'Compound::DB00564',\n",
       " 'Compound::DB05381',\n",
       " 'Compound::DB01235',\n",
       " 'Compound::DB11780',\n",
       " 'Compound::DB00945',\n",
       " 'Compound::DB00571',\n",
       " 'Compound::DB00674',\n",
       " 'Compound::DB01234',\n",
       " 'Compound::DB00134',\n",
       " 'Compound::DB08842',\n",
       " 'Compound::DB00136',\n",
       " 'Compound::DB00368',\n",
       " 'Compound::DB12310',\n",
       " 'Compound::DB00682',\n",
       " 'Compound::DB00459',\n",
       " 'Compound::DB03929',\n",
       " 'Compound::DB02543',\n",
       " 'Compound::DB03575',\n",
       " 'Compound::DB01050',\n",
       " 'Compound::DB01017',\n",
       " 'Compound::DB00533',\n",
       " 'Compound::DB00859',\n",
       " 'Compound::DB04703',\n",
       " 'Compound::DB15357',\n",
       " 'Compound::DB01645',\n",
       " 'Compound::DB06479',\n",
       " 'Compound::DB12052',\n",
       " 'Compound::DB03614',\n",
       " 'Compound::DB13668',\n",
       " 'Compound::DB13084',\n",
       " 'Compound::DB00422',\n",
       " 'Compound::DB00809',\n",
       " 'Compound::DB11336',\n",
       " 'Compound::DB12148',\n",
       " 'Compound::DB14086',\n",
       " 'Compound::DB06750',\n",
       " 'Compound::DB01403',\n",
       " 'Compound::DB01611',\n",
       " 'Compound::DB01200',\n",
       " 'Compound::DB02701',\n",
       " 'Compound::DB00981',\n",
       " 'Compound::DB04557',\n",
       " 'Compound::DB00746',\n",
       " 'Compound::DB14128',\n",
       " 'Compound::DB00126',\n",
       " 'Compound::DB03467',\n",
       " 'Compound::DB04422',\n",
       " 'Compound::DB01041',\n",
       " 'Compound::DB00115',\n",
       " 'Compound::DB04864',\n",
       " 'Compound::DB01104',\n",
       " 'Compound::DB15579',\n",
       " 'Compound::DB00635',\n",
       " 'Compound::DB09163',\n",
       " 'Compound::DB13178',\n",
       " 'Compound::DB09149',\n",
       " 'Compound::DB12116',\n",
       " 'Compound::DB13132',\n",
       " 'Compound::DB01796',\n",
       " 'Compound::DB00734',\n",
       " 'Compound::DB01394',\n",
       " 'Compound::DB00091',\n",
       " 'Compound::DB04892',\n",
       " 'Compound::DB00325',\n",
       " 'Compound::DB01043',\n",
       " 'Compound::DB09151',\n",
       " 'Compound::DB01037',\n",
       " 'Compound::DB01065',\n",
       " 'Compound::DB11473',\n",
       " 'Compound::DB00641',\n",
       " 'Compound::DB00328',\n",
       " 'Compound::DB01381',\n",
       " 'Compound::DB04743',\n",
       " 'Compound::DB00989',\n",
       " 'Compound::DB00337',\n",
       " 'Compound::DB09148',\n",
       " 'Compound::DB01370',\n",
       " 'Compound::DB00656',\n",
       " 'Compound::DB04660',\n",
       " 'Compound::DB04365',\n",
       " 'Compound::DB01219',\n",
       " 'Compound::DB11797',\n",
       " 'Compound::DB01698',\n",
       " 'Compound::DB00987',\n",
       " 'Compound::DB00741',\n",
       " 'Compound::DB06527',\n",
       " 'Compound::DB01750',\n",
       " 'Compound::DB11725',\n",
       " 'Compound::DB01392',\n",
       " 'Compound::DB09157',\n",
       " 'Compound::DB00122',\n",
       " 'Compound::DB11672',\n",
       " 'Compound::DB02530',\n",
       " 'Compound::DB00313',\n",
       " 'Compound::DB00207']"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ad_drugs"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get pretrained model\n",
    "\n",
    "We can use our pretrained model directly for drug repurposing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:20.712990Z",
     "start_time": "2023-03-04T15:00:20.708006Z"
    }
   },
   "outputs": [],
   "source": [
    "entity_id_file = '../../data/drkg/entities.tsv'\n",
    "relation_id_file = '../../data/drkg/relations.tsv'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Get embeddings for diseases and drugs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:21.005223Z",
     "start_time": "2023-03-04T15:00:20.713987Z"
    }
   },
   "outputs": [],
   "source": [
    "# Get drugname/disease name to entity ID mappings\n",
    "entity_map = {}\n",
    "entity_id_map = {}\n",
    "relation_map = {}\n",
    "with open(entity_id_file, newline='', encoding='utf-8') as csvfile:\n",
    "    reader = csv.DictReader(csvfile, delimiter='\\t', fieldnames=['id', 'name'])\n",
    "    for row_val in reader:\n",
    "        entity_map[row_val['name']] = int(row_val['id'])\n",
    "        entity_id_map[int(row_val['id'])] = row_val['name']\n",
    "        \n",
    "with open(relation_id_file, newline='', encoding='utf-8') as csvfile:\n",
    "    reader = csv.DictReader(csvfile, delimiter='\\t', fieldnames=['id', 'name'])\n",
    "    for row_val in reader:\n",
    "        relation_map[row_val['name']] = int(row_val['id'])\n",
    "        \n",
    "# handle the ID mapping\n",
    "drug_ids = []\n",
    "disease_ids = []\n",
    "for drug in drug_list:\n",
    "    drug_ids.append(entity_map[drug])\n",
    "    \n",
    "for disease in AD_disease_list:\n",
    "    disease_ids.append(entity_map[disease])\n",
    "\n",
    "treatment_ids = [relation_map[treat]  for treat in treatment_list]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:21.015182Z",
     "start_time": "2023-03-04T15:00:21.007203Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(3, 8104, 3)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(disease_ids),len(drug_ids),len(treatment_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:21.024158Z",
     "start_time": "2023-03-04T15:00:21.018179Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "([83488, 45467, 25842], [9475, 11010, 7486], [24, 35, 68])"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "disease_ids, drug_ids[:3], treatment_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:21.742239Z",
     "start_time": "2023-03-04T15:00:21.026154Z"
    }
   },
   "outputs": [],
   "source": [
    "# Load embeddings\n",
    "import torch\n",
    "import numpy as np\n",
    "entity_emb = np.load('../01-model/ckpts/RotatE_All_DRKG_0/All_DRKG_RotatE_entity.npy')\n",
    "rel_emb = np.load('../01-model/ckpts/RotatE_All_DRKG_0/All_DRKG_RotatE_relation.npy')\n",
    "\n",
    "drug_ids = torch.tensor(drug_ids).long()\n",
    "disease_ids = torch.tensor(disease_ids).long()\n",
    "treatment_ids = torch.tensor(treatment_ids)\n",
    "\n",
    "drug_emb = torch.tensor(entity_emb[drug_ids])\n",
    "treatment_embs = [torch.tensor(rel_emb[r_id]) for r_id in treatment_ids]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((97238, 400), (107, 200))"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "entity_emb.shape, rel_emb.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:21.749221Z",
     "start_time": "2023-03-04T15:00:21.743236Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([83488, 45467, 25842]),\n",
       " tensor([ 9475, 11010,  7486]),\n",
       " tensor([24, 35, 68]))"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "disease_ids, drug_ids[:3], treatment_ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:21.756202Z",
     "start_time": "2023-03-04T15:00:21.751216Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8104, 400])"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "drug_emb.shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Drug Repurposing Based on Edge Score\n",
    "\n",
    "We use the following formulas to compute the edge score. Note that we apply log(sigmoid) so that all scores are negative. The higher the score, the more likely the triple $(h, r, t)$ holds.\n",
    "\n",
    "$\\mathbf{d} = \\gamma - ||\\mathbf{h}\\circ \\mathbf{r}-\\mathbf{t}||_{2}$\n",
    "\n",
    "$\\mathbf{score} = \\log\\left(\\frac{1}{1+\\exp(\\mathbf{-d})}\\right)$\n",
    "\n",
    "For drug repurposing, we only use the treatment-related relations."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The official DGL-KE implementation of the RotatE score function is at:\n",
    "\n",
    "- https://github.com/awslabs/dgl-ke/blob/master/python/dglke/models/ke_model.py (lines 940 - 955)\n",
    "\n",
    "- https://github.com/awslabs/dgl-ke/blob/master/python/dglke/models/pytorch/score_fun.py (lines 460 - 472)\n",
    "\n",
    "The OpenKE implementation of the RotatE score function is at:\n",
    "\n",
    "- https://github.com/thunlp/OpenKE/blob/OpenKE-PyTorch/openke/module/model/RotatE.py\n",
    "\n",
    "The two implementations are essentially the same.\n",
    "\n",
    "Both use an epsilon of 2.0:\n",
    "\n",
    "- https://github.com/awslabs/dgl-ke/blob/master/python/dglke/models/ke_model.py (line 53)\n",
    "\n",
    "- https://github.com/thunlp/OpenKE/blob/OpenKE-PyTorch/examples/train_rotate_WN18RR_adv.py (line 29)"
   ]
  },
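  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The edge score can also be written out component by component. The following is a sketch read off the `RotatE_distance` function in this notebook, not taken from the DGL-KE or OpenKE documentation; the phase $\\theta$, the subscripts $re$ / $im$, the index $j$, and $n$ are notation introduced here for illustration. Each entity embedding is split into a real and an imaginary half, and the relation embedding is rescaled into a phase:\n",
    "\n",
    "$$\\theta = \\frac{\\pi\\,\\mathbf{r}}{(\\gamma + \\epsilon)/n}, \\qquad n = 200,\\ \\gamma = 18,\\ \\epsilon = 2$$\n",
    "\n",
    "$$\\mathbf{h}\\circ\\mathbf{r} = (\\mathbf{h}_{re}\\cos\\theta - \\mathbf{h}_{im}\\sin\\theta) + i\\,(\\mathbf{h}_{re}\\sin\\theta + \\mathbf{h}_{im}\\cos\\theta)$$\n",
    "\n",
    "$$d = \\gamma - \\sum_{j=1}^{n}\\sqrt{\\left(\\mathrm{Re}(\\mathbf{h}\\circ\\mathbf{r})_j - \\mathbf{t}_{re,j}\\right)^2 + \\left(\\mathrm{Im}(\\mathbf{h}\\circ\\mathbf{r})_j - \\mathbf{t}_{im,j}\\right)^2}$$\n",
    "\n",
    "With these values $(\\gamma + \\epsilon)/n = 0.1$, so $\\theta = 10\\pi\\,\\mathbf{r}$. In this reading the norm is the sum over the $n$ complex dimensions of the per-dimension modulus."
   ]
  },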
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:21.767212Z",
     "start_time": "2023-03-04T15:00:21.757200Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch.nn.functional as fn\n",
    "\n",
    "gamma = 18.0       # margin\n",
    "hidden_dim = 200   # number of complex dimensions (entity embeddings are 2 * hidden_dim wide)\n",
    "epsilon = 2.0\n",
    "emb_init = (gamma + epsilon) / hidden_dim  # embedding range used to rescale relation phases\n",
    "\n",
    "def RotatE_distance(head, rel, tail):\n",
    "    \n",
    "    # Split entity embeddings into real and imaginary halves\n",
    "    re_head, im_head = torch.chunk(head, 2, dim = -1)\n",
    "    re_tail, im_tail = torch.chunk(tail, 2, dim = -1)\n",
    "    \n",
    "    # Map relation embeddings to phases and build the unit-modulus rotation\n",
    "    phase_rel = rel / (emb_init / np.pi)\n",
    "    re_rel, im_rel = torch.cos(phase_rel), torch.sin(phase_rel)\n",
    "    \n",
    "    # Complex product h * r, then subtract t\n",
    "    re_score = re_head * re_rel - im_head * im_rel\n",
    "    im_score = re_head * im_rel + im_head * re_rel\n",
    "    re_score = re_score - re_tail\n",
    "    im_score = im_score - im_tail\n",
    "    \n",
    "    # Per-dimension complex modulus, summed over all dimensions\n",
    "    score = torch.stack([re_score, im_score], dim = 0)\n",
    "    score = score.norm(dim = 0).sum(dim = -1)\n",
    "    \n",
    "    return gamma - score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:23.573346Z",
     "start_time": "2023-03-04T15:00:21.769168Z"
    }
   },
   "outputs": [],
   "source": [
    "scores_per_disease = []\n",
    "drugs = []\n",
    "for r_id in range(len(treatment_embs)):\n",
    "    treatment_emb = treatment_embs[r_id]\n",
    "    for disease_id in disease_ids:\n",
    "        disease_emb = torch.tensor(entity_emb[disease_id])\n",
    "        score = fn.logsigmoid(RotatE_distance(drug_emb, treatment_emb, disease_emb))\n",
    "        scores_per_disease.append(score)\n",
    "        drugs.append(drug_ids)\n",
    "scores = torch.cat(scores_per_disease)\n",
    "drugs = torch.cat(drugs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:23.580329Z",
     "start_time": "2023-03-04T15:00:23.574344Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([72936]), torch.Size([72936]), 72936)"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "scores.shape, drugs.shape, 3*3*8104"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:23.609251Z",
     "start_time": "2023-03-04T15:00:23.583324Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((72936,), (72936,), 72936)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# sort scores in descending order\n",
    "# torch.flip: reverse the order of an n-D tensor along the given axes in dims.\n",
    "idx = torch.flip(torch.argsort(scores), dims = [0])\n",
    "scores = scores[idx].numpy()\n",
    "drugs = drugs[idx].numpy()\n",
    "\n",
    "scores.shape, drugs.shape, 3*3*8104"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Now we output proposed treatments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:23.624210Z",
     "start_time": "2023-03-04T15:00:23.610248Z"
    }
   },
   "outputs": [],
   "source": [
    "# drugs is sorted by descending score, so the first occurrence of each drug is its best score\n",
    "_, unique_indices = np.unique(drugs, return_index=True)\n",
    "topk = 50\n",
    "topk_indices = np.sort(unique_indices)[:topk]\n",
    "proposed_drugs = drugs[topk_indices]\n",
    "proposed_scores = scores[topk_indices]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2023-03-04T15:00:23.636179Z",
     "start_time": "2023-03-04T15:00:23.626206Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[9]\tCompound::DB00143\t-0.1855190545320511\n",
      "[11]\tCompound::DB00502\t-0.19982466101646423\n",
      "[13]\tCompound::DB06774\t-0.2012828290462494\n",
      "[16]\tCompound::DB04216\t-0.20350724458694458\n",
      "[17]\tCompound::DB00783\t-0.2035258710384369\n",
      "[18]\tCompound::DB09341\t-0.20749205350875854\n",
      "[20]\tCompound::DB00822\t-0.21294096112251282\n",
      "[21]\tCompound::DB00640\t-0.22600767016410828\n",
      "[23]\tCompound::DB01105\t-0.23136001825332642\n",
      "[29]\tCompound::DB00715\t-0.26189476251602173\n",
      "[31]\tCompound::DB00907\t-0.2680314779281616\n",
      "[39]\tCompound::DB01229\t-0.28694769740104675\n",
      "[41]\tCompound::DB04540\t-0.2887122929096222\n",
      "[43]\tCompound::DB01016\t-0.29204732179641724\n",
      "[44]\tCompound::DB02010\t-0.2982581555843353\n",
      "[46]\tCompound::DB14681\t-0.29944315552711487\n",
      "[48]\tCompound::DB00321\t-0.30107274651527405\n"
     ]
    }
   ],
   "source": [
    "top50_list = []\n",
    "\n",
    "for i in range(topk):\n",
    "    drug = int(proposed_drugs[i])\n",
    "    if entity_id_map[drug] not in ad_drugs:\n",
    "        score = proposed_scores[i]\n",
    "        row = \"[{}],{},{}\\n\".format(i+1, entity_id_map[drug], score)\n",
    "        top50_list.append(row)\n",
    "        print(\"[{}]\\t{}\\t{}\".format(i+1, entity_id_map[drug], score))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 输出结果到文件\n",
    "\n",
    "f = open('./results/04_rotatE_top50.csv', 'w')\n",
    "for row in top50_list:\n",
    "    f.write(row)\n",
    "f.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### AD treatment drugs in the top 50 that already exist in the original knowledge graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1]\tCompound::DB00989\t-0.030231274664402008\n",
      "[2]\tCompound::DB00674\t-0.04483981430530548\n",
      "[3]\tCompound::DB09151\t-0.05556671693921089\n",
      "[4]\tCompound::DB00843\t-0.05607687309384346\n",
      "[5]\tCompound::DB11748\t-0.11345928907394409\n",
      "[6]\tCompound::DB13084\t-0.14315257966518402\n",
      "[7]\tCompound::DB04660\t-0.1493753343820572\n",
      "[8]\tCompound::DB00472\t-0.17104481160640717\n",
      "[10]\tCompound::DB15579\t-0.18798679113388062\n",
      "[12]\tCompound::DB03614\t-0.20075558125972748\n",
      "[14]\tCompound::DB11672\t-0.20285248756408691\n",
      "[15]\tCompound::DB12052\t-0.20329241454601288\n",
      "[19]\tCompound::DB00981\t-0.21120820939540863\n",
      "[22]\tCompound::DB06750\t-0.23092582821846008\n",
      "[24]\tCompound::DB04115\t-0.2372199445962906\n",
      "[25]\tCompound::DB14128\t-0.24131596088409424\n",
      "[26]\tCompound::DB11094\t-0.2423398494720459\n",
      "[27]\tCompound::DB01234\t-0.24693337082862854\n",
      "[28]\tCompound::DB01041\t-0.2616158127784729\n",
      "[30]\tCompound::DB01017\t-0.2666377127170563\n",
      "[32]\tCompound::DB01104\t-0.2696487605571747\n",
      "[33]\tCompound::DB12148\t-0.2770227789878845\n",
      "[34]\tCompound::DB12116\t-0.2800264358520508\n",
      "[35]\tCompound::DB03467\t-0.28069329261779785\n",
      "[36]\tCompound::DB00741\t-0.2809244394302368\n",
      "[37]\tCompound::DB00734\t-0.2819312512874603\n",
      "[38]\tCompound::DB14086\t-0.28337299823760986\n",
      "[40]\tCompound::DB11725\t-0.28758445382118225\n",
      "[42]\tCompound::DB00571\t-0.28952401876449585\n",
      "[45]\tCompound::DB00215\t-0.29838016629219055\n",
      "[47]\tCompound::DB12310\t-0.3007897138595581\n",
      "[49]\tCompound::DB04703\t-0.30450138449668884\n",
      "[50]\tCompound::DB00656\t-0.3045790195465088\n"
     ]
    }
   ],
   "source": [
    "for i in range(topk):\n",
    "    drug = int(proposed_drugs[i])\n",
    "    if entity_id_map[drug] in ad_drugs:\n",
    "        score = proposed_scores[i]\n",
    "        print(\"[{}]\\t{}\\t{}\".format(i+1, entity_id_map[drug], score))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
